# -*- coding: utf-8 -*-
# !/usr/bin/env python
"""
-------------------------------------------------
   File Name：     utils
   Description :   Network weight initialisation helper and loss modules (cross-entropy, Dice)
   Author :       lth
   date：          2022/8/3
-------------------------------------------------
   Change Activity:
                   2022/8/3 18:23: create this script
-------------------------------------------------
"""
__author__ = 'lth'
import torch
from torch import nn


def weights_init(net, init_type='normal', init_gain=0.001):
    """Initialise the weights of every Conv* and BatchNorm2d layer of ``net`` in place.

    Args:
        net: an ``nn.Module``; the initialiser is applied recursively to every
            sub-module via ``net.apply``.
        init_type: scheme used for convolution weights, one of ``'normal'``,
            ``'xavier'``, ``'kaiming'`` or ``'orthogonal'``.
        init_gain: standard deviation for ``'normal'`` and gain for ``'xavier'``
            and ``'orthogonal'``; ignored by ``'kaiming'``.

    Raises:
        NotImplementedError: if ``init_type`` is not a supported name and
            ``net`` contains at least one convolution layer.

    Convolution biases are left untouched. BatchNorm2d layers get
    weight ~ N(1.0, 0.02) and a zero bias. Layers built with ``affine=False``
    have no weight/bias and are skipped.
    """
    def init_func(m):
        classname = m.__class__.__name__
        # Matches Conv1d/2d/3d and ConvTranspose* by class name.
        if 'Conv' in classname and getattr(m, 'weight', None) is not None:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) sets weight and bias to None.
            if m.weight is not None:
                torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s type' % init_type)
    net.apply(init_func)


class OhemCELoss(nn.Module):
    """Cross-entropy loss that honours an ignore label.

    Despite the name, online hard-example mining (OHEM) is currently disabled:
    ``forward`` returns the plain mean cross-entropy over all non-ignored
    elements. ``thresh`` and ``n_min`` are still stored so the mining step can
    be reinstated without changing the constructor signature.

    Args:
        thresh: probability threshold; stored as ``-log(thresh)`` in
            ``self.thresh``.
        n_min: minimum number of hard examples to keep (unused while OHEM is
            disabled).
        ignore_lb: target value excluded from the loss (passed to
            ``nn.CrossEntropyLoss`` as ``ignore_index``).
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        # Non-persistent buffer: follows the module through .to()/.cuda()
        # instead of forcing CUDA at construction (which fails on CPU-only
        # hosts), and stays out of state_dict like the former plain attribute.
        self.register_buffer(
            'thresh',
            -torch.log(torch.tensor(thresh, dtype=torch.float)),
            persistent=False,
        )
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        # Without ignore_index, targets equal to ignore_lb (e.g. 255) would be
        # treated as class indices and raise an out-of-range error.
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='mean')

    def forward(self, logits, labels):
        """Return the mean cross-entropy of ``logits`` against ``labels``.

        Args:
            logits: unnormalised class scores with the class dimension at
                dim 1, as required by ``nn.CrossEntropyLoss``.
            labels: integer class indices; entries equal to ``ignore_lb`` do
                not contribute. If every entry is ignored the mean is NaN.
        """
        return self.criteria(logits, labels)

class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities against one-hot encoded targets.

    For each sample, the overlap is pooled over all classes and all positions,
    with additive smoothing of 1. The returned loss is the batch mean of
    ``1 - dice``.
    """

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, pred, target):
        """Module entry point; delegates to :meth:`calculate`."""
        return self.calculate(pred, target)

    @staticmethod
    def calculate(pred, target):
        """Return the batch-mean soft Dice loss.

        Args:
            pred: raw logits of shape ``(N, C, *)``, turned into probabilities
                with an element-wise sigmoid.
            target: integer (int64) class indices of shape ``(N, *)``, each in
                ``[0, C)``, as required by ``scatter_``.

        Returns:
            A scalar tensor: the mean over the batch of
            ``1 - (2 * sum(p * t) + 1) / (sum(p) + sum(t) + 1)``.
        """
        # One-hot encode the targets along the class dimension.
        lb_one_hot = torch.zeros_like(pred).scatter_(1, target.unsqueeze(1), 1)
        probs = torch.sigmoid(pred)

        # Reduce over every non-batch dimension (classes plus any number of
        # spatial dims), giving one overlap/size pair per sample.
        reduce_dims = tuple(range(1, pred.dim()))
        numer = torch.sum(probs * lb_one_hot, dim=reduce_dims)
        denom = torch.sum(probs + lb_one_hot, dim=reduce_dims)

        smooth = 1
        loss = 1 - (2 * numer + smooth) / (denom + smooth)
        return loss.mean()